{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e4b4554",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='hfl/rbt3', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Load the tokenizer\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Fetch the tokenizer that belongs to the hfl/rbt3 checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    pretrained_model_name_or_path='hfl/rbt3')\n",
    "\n",
    "# Bare last expression so the notebook displays the tokenizer repr.\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b3086c2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[101, 3209, 3299, 6163, 7652, 749, 872, 4638, 4970, 2094, 102], [101, 872, 6163, 7652, 749, 1166, 782, 4638, 3457, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Try encoding a couple of sentences\n",
    "tokenizer.batch_encode_plus(\n",
    "    ['明月装饰了你的窗子', '你装饰了别人的梦'],\n",
    "    truncation=True,\n",
    "    # This tokenizer reports no usable model_max_length, so truncation=True\n",
    "    # on its own is skipped with a warning. An explicit max_length makes the\n",
    "    # cap effective; 512 matches the length limit this notebook applies when\n",
    "    # it later filters out over-long samples.\n",
    "    max_length=512,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bdd8341a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 2000\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 100\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Load the dataset from disk\n",
    "from datasets import load_from_disk\n",
    "\n",
    "dataset = load_from_disk('./data/ChnSentiCorp')\n",
    "\n",
    "# Shrink the data so the demo runs quickly. A fixed shuffle seed makes the\n",
    "# sampled subset identical on every Restart & Run All.\n",
    "dataset['train'] = dataset['train'].shuffle(seed=42).select(range(2000))\n",
    "dataset['test'] = dataset['test'].shuffle(seed=42).select(range(100))\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "996c8e5a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        "
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7196c0cd9bb445f1a1e52988da589323",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#0:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd4666a61e31469987b912abfa9c739b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#1:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be3adf274e64461fb44f91570f8b69a7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#3:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "189255366e5f4aac9bf277c783267fc7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#2:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        "
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90148dcf7bd346979c268c2d0aa75500",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#0:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83f021b8f0ab49768eb1a074f67fa726",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#1:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e51a8a218fce4e75919977af933c11e1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#3:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "325947e5d1394824a41b75d262df1e1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#2:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 2000\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['label'],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 100\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Encoding\n",
    "def f(data):\n",
    "    \"\"\"Tokenize one batch of rows; ``data['text']`` holds the raw sentences.\n",
    "\n",
    "    Returns whatever ``tokenizer.batch_encode_plus`` produces for the batch.\n",
    "    \"\"\"\n",
    "    # truncation=True alone is a no-op here: the tokenizer has no usable\n",
    "    # model_max_length, so an explicit max_length is needed for the cap to\n",
    "    # take effect. 512 matches the limit used by the length filter that\n",
    "    # runs after this step.\n",
    "    return tokenizer.batch_encode_plus(data['text'],\n",
    "                                       truncation=True,\n",
    "                                       max_length=512)\n",
    "\n",
    "\n",
    "# Run f over the dataset in batches with 4 worker processes and drop the\n",
    "# raw text column afterwards.\n",
    "map_options = dict(batched=True,\n",
    "                   batch_size=1000,\n",
    "                   num_proc=4,\n",
    "                   remove_columns=['text'])\n",
    "dataset = dataset.map(f, **map_options)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "196060e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        "
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7270cc3aab547a2be431d8fe86fb248",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#0:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f80d9083665d4dd2a01edc4b700c6a50",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#1:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "838e134c269a4daba55e51249fe55f68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#2:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59347230e1514d499089354aea1548b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#3:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        "
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22c491e9287346c8a9a454ebc8aa8c4d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#0:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7862d384d7f4330bb24751079f52f91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#2:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e734a60bd714daaae9f7b55354bbaa9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#1:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "013fa96c577c4af7b912e1d2cf3ccc1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#3:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 1980\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['label'],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 99\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Drop sentences that are too long\n",
    "def f(data):\n",
    "    \"\"\"Return one flag per row: True when the row has at most 512 token ids.\"\"\"\n",
    "    flags = []\n",
    "    for token_ids in data['input_ids']:\n",
    "        flags.append(not len(token_ids) > 512)\n",
    "    return flags\n",
    "\n",
    "\n",
    "# Keep only the rows that f flags as True.\n",
    "dataset = dataset.filter(\n",
    "    f,\n",
    "    batched=True,\n",
    "    batch_size=1000,\n",
    "    num_proc=4,\n",
    ")\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ad583ee8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/rbt3 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3847.8338"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Load the model\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "# Sequence-classification model built from hfl/rbt3 with two labels.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    'hfl/rbt3',\n",
    "    num_labels=2,\n",
    ")\n",
    "\n",
    "# Total number of parameter elements, reported in units of 10,000.\n",
    "param_count = 0\n",
    "for param in model.parameters():\n",
    "    param_count += param.nelement()\n",
    "param_count / 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b5bde364",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.4744, grad_fn=<NllLossBackward>), torch.Size([4, 2]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Trial forward pass\n",
    "import torch\n",
    "\n",
    "# Fake batch: 4 sequences of 10 positions, every entry set to 1.\n",
    "data = {\n",
    "    name: torch.ones(4, 10, dtype=torch.long)\n",
    "    for name in ('input_ids', 'token_type_ids', 'attention_mask')\n",
    "}\n",
    "# One label per sequence.\n",
    "data['labels'] = torch.ones(4, dtype=torch.long)\n",
    "\n",
    "# Run the model once on the fake batch.\n",
    "out = model(**data)\n",
    "\n",
    "# Show the loss and the shape of the logits.\n",
    "loss, logits = out['loss'], out['logits']\n",
    "loss, logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8dec01a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chapter 6 / Load the evaluation metric\n",
    "# Disabled: the compute_metrics cell calculates accuracy by hand instead,\n",
    "# and its metric.compute call is likewise commented out.\n",
    "# from datasets import load_metric\n",
    "\n",
    "# metric = load_metric('accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "403170ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.75}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第6章/定义评价函数\n",
    "import numpy as np\n",
    "from transformers.trainer_utils import EvalPrediction\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute accuracy from a (logits, labels) pair.\n",
    "\n",
    "    The predicted class of each row is the index of its largest logit.\n",
    "    Returns a dict with the single key 'accuracy'.\n",
    "    \"\"\"\n",
    "    scores, labels = eval_pred\n",
    "    predicted = scores.argmax(axis=1)\n",
    "    n_correct = (predicted == labels).sum()\n",
    "    # Alternative: metric.compute(predictions=predicted, references=labels)\n",
    "    return {'accuracy': n_correct / len(labels)}\n",
    "\n",
    "\n",
    "# Simulated model output: four rows of logits and their labels.\n",
    "fake_logits = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])\n",
    "fake_labels = np.array([1, 1, 0, 1])\n",
    "eval_pred = EvalPrediction(predictions=fake_logits, label_ids=fake_labels)\n",
    "\n",
    "compute_metrics(eval_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d86c46ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chapter 6 / Define the training arguments\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "# Hyper-parameters and bookkeeping options handed to the Trainer.\n",
    "args = TrainingArguments(\n",
    "    # Directory where temporary run data is saved\n",
    "    output_dir='./output_dir',\n",
    "\n",
    "    # Number of training epochs\n",
    "    num_train_epochs=1,\n",
    "\n",
    "    # Learning rate\n",
    "    learning_rate=1e-4,\n",
    "\n",
    "    # Weight decay, added to curb overfitting\n",
    "    weight_decay=1e-2,\n",
    "\n",
    "    # Batch sizes per device for training and for evaluation\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "\n",
    "    # When to evaluate (no / epoch / steps) and the step interval\n",
    "    evaluation_strategy='steps',\n",
    "    eval_steps=30,\n",
    "\n",
    "    # When to save the model (no / epoch / steps) and the step interval\n",
    "    save_strategy='steps',\n",
    "    save_steps=30,\n",
    "\n",
    "    # Whether to disable CUDA; False leaves GPU training enabled\n",
    "    no_cuda=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5a5748cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chapter 6 / Define the trainer\n",
    "from transformers import Trainer\n",
    "from transformers.data.data_collator import DataCollatorWithPadding\n",
    "\n",
    "# Gather the trainer settings first, then build the trainer from them.\n",
    "trainer_kwargs = {\n",
    "    'model': model,\n",
    "    'args': args,\n",
    "    'train_dataset': dataset['train'],\n",
    "    'eval_dataset': dataset['test'],\n",
    "    'compute_metrics': compute_metrics,\n",
    "    # Collator built from the tokenizer; it is tried out in a later cell\n",
    "    'data_collator': DataCollatorWithPadding(tokenizer),\n",
    "}\n",
    "trainer = Trainer(**trainer_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3ab1ee63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "71\n",
      "53\n",
      "75\n",
      "42\n",
      "181\n",
      "input_ids torch.Size([5, 181])\n",
      "token_type_ids torch.Size([5, 181])\n",
      "attention_mask torch.Size([5, 181])\n",
      "labels torch.Size([5])\n"
     ]
    }
   ],
   "source": [
    "# Chapter 6 / Try out the data collator\n",
    "data_collator = DataCollatorWithPadding(tokenizer)\n",
    "\n",
    "# Take the first five training rows.\n",
    "data = dataset['train'][:5]\n",
    "\n",
    "# Show each sentence's token count before collation.\n",
    "for token_ids in data['input_ids']:\n",
    "    print(len(token_ids))\n",
    "\n",
    "# Collate the rows into one batch.\n",
    "data = data_collator(data)\n",
    "\n",
    "# Show the shape of every field in the collated batch.\n",
    "for name, tensor in data.items():\n",
    "    print(name, tensor.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4ac07ba4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] 收 到 书 之 后 第 一 感 觉 就 是 脏 脏 的, 随 手 翻 了 几 页, 发 现 里 面 竟 有 破 损 的 纸 张, 我 倒! 有 的 页 面 上 还 有 漆 黑 的 脚 印, 当 当 的 工 作 人 员 发 书 时 难 道 不 确 认 么??? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the first collated row back to text to inspect its padding.\n",
    "first_row = data['input_ids'][0]\n",
    "tokenizer.decode(first_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f7508e31",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='14' max='7' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7/7 00:10]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.716060996055603,\n",
       " 'eval_accuracy': 0.42424242424242425,\n",
       " 'eval_runtime': 0.5896,\n",
       " 'eval_samples_per_second': 167.899,\n",
       " 'eval_steps_per_second': 11.872}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Evaluate the model\n",
    "# Evaluates on the eval_dataset passed to Trainer. This call comes before\n",
    "# trainer.train(), so it measures the model prior to fine-tuning.\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2861051f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/gpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n",
      "***** Running training *****\n",
      "  Num examples = 1980\n",
      "  Num Epochs = 1\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 124\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='124' max='124' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [124/124 00:45, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.368228</td>\n",
       "      <td>0.838384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.342529</td>\n",
       "      <td>0.868687</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.452716</td>\n",
       "      <td>0.858586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.307068</td>\n",
       "      <td>0.898990</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to ./output_dir/checkpoint-30\n",
      "Configuration saved in ./output_dir/checkpoint-30/config.json\n",
      "Model weights saved in ./output_dir/checkpoint-30/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to ./output_dir/checkpoint-60\n",
      "Configuration saved in ./output_dir/checkpoint-60/config.json\n",
      "Model weights saved in ./output_dir/checkpoint-60/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to ./output_dir/checkpoint-90\n",
      "Configuration saved in ./output_dir/checkpoint-90/config.json\n",
      "Model weights saved in ./output_dir/checkpoint-90/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to ./output_dir/checkpoint-120\n",
      "Configuration saved in ./output_dir/checkpoint-120/config.json\n",
      "Model weights saved in ./output_dir/checkpoint-120/pytorch_model.bin\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=124, training_loss=0.39254542320005353, metrics={'train_runtime': 45.7338, 'train_samples_per_second': 43.294, 'train_steps_per_second': 2.711, 'total_flos': 70209086949120.0, 'train_loss': 0.39254542320005353, 'epoch': 1.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Train\n",
    "# Runs training with the settings in args; evaluation and checkpoint\n",
    "# saving are both configured there to happen every 30 steps.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5971dd75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading model from ./output_dir/checkpoint-90).\n",
      "***** Running training *****\n",
      "  Num examples = 1980\n",
      "  Num Epochs = 1\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 124\n",
      "  Continuing training from checkpoint, will skip to saved global_step\n",
      "  Continuing training from epoch 0\n",
      "  Continuing training from global step 90\n",
      "  Will skip the first 0 epochs then the first 90 batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` flag to your launch command, but you will resume the training on data already seen by your model.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e45902458956449487b7457d4cc2ee52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/90 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='124' max='124' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [124/124 00:12, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.359170</td>\n",
       "      <td>0.868687</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to ./output_dir/checkpoint-120\n",
      "Configuration saved in ./output_dir/checkpoint-120/config.json\n",
      "Model weights saved in ./output_dir/checkpoint-120/pytorch_model.bin\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=124, training_loss=0.07213338728873961, metrics={'train_runtime': 13.1996, 'train_samples_per_second': 150.004, 'train_steps_per_second': 9.394, 'total_flos': 68529398541984.0, 'train_loss': 0.07213338728873961, 'epoch': 1.0})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Resume training from a saved checkpoint\n",
    "# Needs ./output_dir/checkpoint-90 to exist, i.e. the earlier training run\n",
    "# (which saves every 30 steps) must have been executed first.\n",
    "trainer.train(resume_from_checkpoint='./output_dir/checkpoint-90')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "22dc37ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 99\n",
      "  Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7' max='7' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7/7 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.36910974979400635,\n",
       " 'eval_accuracy': 0.8484848484848485,\n",
       " 'eval_runtime': 0.5673,\n",
       " 'eval_samples_per_second': 174.517,\n",
       " 'eval_steps_per_second': 12.34,\n",
       " 'epoch': 1.0}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Evaluate the model\n",
    "# Same evaluation call as before training, now run after the training cells.\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8d024896",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./output_dir/save_model\n",
      "Configuration saved in ./output_dir/save_model/config.json\n",
      "Model weights saved in ./output_dir/save_model/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "# Chapter 6 / Manually save the model weights\n",
    "# Writes the model into the given directory; the next cell reads\n",
    "# pytorch_model.bin back from this location.\n",
    "trainer.save_model(output_dir='./output_dir/save_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c386463d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 6 / Manually load the model weights\n",
    "import torch\n",
    "\n",
    "# map_location='cpu' lets the file be read on a machine that lacks the\n",
    "# device the weights were saved from; load_state_dict then copies the values\n",
    "# into the model's existing parameters.\n",
    "# NOTE(review): torch.load unpickles the file, so only load trusted files.\n",
    "state_dict = torch.load('./output_dir/save_model/pytorch_model.bin',\n",
    "                        map_location='cpu')\n",
    "\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c46f3818",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原 3255 ， 300 卷 + 瑞 星 2009.. 实 际 价 : 2995. 比 较 实 惠 。 屏 比 较 大 哦 。 独 显 不 错 ， 一 般 够 用 了 。\n",
      "label= 1\n",
      "predict= 1\n",
      "服 务 很 差 ！ ！ ！ 服 务 员 的 态 度 很 不 好 ！ 环 境 也 一 般\n",
      "label= 0\n",
      "predict= 0\n",
      "不 足 就 是 没 有 配 套 光 盘 ， 需 要 自 己 备 份 。 我 不 太 喜 欢 这 种 触 摸 板 ， 没 有 送 鼠 标 ， 还 要 自 己 去 配 ， 这 个 都 是 小 问 题 。 无 线 网 卡 插 槽 的 按 钮 太 突 出 了 ， 很 怕 不 小 心 会 弄 断 。 不 过 使 用 不 要 那 么 野 蛮 呵 呵 就 不 会 有 大 问 题 。\n",
      "label= 0\n",
      "predict= 0\n",
      "设 施 很 不 错 ； 周 边 环 境 不 错 ， 紧 邻 金 沙 滩 ； 推 荐 海 景 房 不 足 ： 比 较 偏 ； 吃 饭 不 放 便 ； 酒 店 自 带 的 餐 厅 性 价 比 太 差 。 价 格 太 贵 ， 口 味 不 好 ；\n",
      "label= 1\n",
      "predict= 1\n",
      "大 家 都 说 这 本 书 写 了 很 多 心 理 案 例, 但 我 觉 得 什 么 案 例 呀 ， 不 就 是 作 者 编 的 故 事 么 ， 而 且 欠 分 析 。 说 什 么 是 心 理 学 的 辅 读 读 物 ， 我 看 就 是 把 一 些 故 事 串 起 来 ， 再 打 乱 一 下 顺 序 而 已 。 推 荐 那 些 喜 欢 看 小 说 的 人 看 ， 但 真 想 学 到 点 心 理 学 知 识 ， 那 就 免 谈 了 。\n",
      "label= 0\n",
      "predict= 0\n",
      "贾 志 刚 是 有 才 的 ， 所 以 写 了 春 秋 ， 但 人 嘛 就 向 春 秋 那 样 乱 ， 不 对 ， 说 错 了 ！ 应 该 说 ， 人 心 嘛 ， 就 像 春 秋 那 样 乱 。 所 以 《 春 秋 》 一 出 ， 赞 弹 皆 有 ， 其 实 贾 志 刚 说 得 对 ， 不 看 春 秋 ， 连 祖 宗 姓 什 么 都 不 知 道 呢 ！ 虽 然 书 中 句 子 略 有 不 雅 ， 就 是 俗 了 点 ， 但 俗 容 易 看 明 白 。 如 果 写 书 的 都 整 点 文 言 文 或 者 为 不 俗 ， 故 意 转 弯 没 角 ， 春 秋 不 用 写 了 ， 也 不 用 看 了 ， 祖 宗 也 不 用 理 了 ！ 所 以 春 秋 要 看 ， 不 管 他 写 得 俗 不 俗 。\n",
      "label= 1\n",
      "predict= 0\n",
      "dell 商 务 机 做 工 不 错 ， 外 观 设 计 比 以 前 的 商 务 机 好 很 多 了 。 价 格 又 便 宜 ， 这 不 错 。\n",
      "label= 1\n",
      "predict= 1\n",
      "差 劲 ， 并 没 有 特 殊 的 ， 我 订 了 2 天 的 房 ， 到 第 三 天 续 房 要 换 房 ， 结 果 忘 了 拿 抽 屉 里 的 dv 机 ， 第 二 天 ， 下 午 电 话 给 客 房 中 心 ， 说 给 回 复 情 况 ， 结 果 未 给 回 复 ； 打 给 大 堂 值 班 经 理 投 诉 并 要 求 协 助 ； 投 诉 理 由 ： 1. 不 知 四 还 是 五 星 的 宾 馆 换 房 怎 么 无 人 及 时 发 现 客 人 遗 留 物 品 ， 及 及 时 通 知 客 人 。 她 的 理 由 是 换 房 与 退 房 不 同 ， 房 间 第 二 天 早 上 才 有 服 务 员 进 去 打 扫 的 。 。 。 。 总 之 ， 什 么 凯 莱 ， 其 实 也 是 很 差 的 ， 不 要 对 他 寄 望 什 么 。\n",
      "label= 0\n",
      "predict= 0\n",
      "性 价 比 高 同 等 价 位 的 其 它 品 牌 配 置 没 这 么 高 的 （ 除 了 个 别 ） 外 观 不 错 做 工 较 好\n",
      "label= 1\n",
      "predict= 1\n",
      "送 的 包 比 较 普 通, 做 工 也 不 好, 和 商 城 上 99 块 那 款 东 芝 笔 记 本 包 是 一 样 的, 希 望 能 送 好 点 的 包\n",
      "label= 0\n",
      "predict= 0\n",
      "书 中 毛 主 席 的 每 一 个 思 维 方 法 都 要 看 好 几 遍 才 能 深 刻 理 解 ， 对 日 常 生 活 和 工 作 都 有 帮 助 ， 如 果 真 的 弄 懂 了 能 让 思 维 水 平 上 升 一 个 层 次 。 这 本 书 乍 一 看 挺 抽 象 ， 其 实 你 要 循 着 作 者 的 脉 络 和 自 己 的 思 考 去 读 ， 是 很 有 分 量 的 。 能 当 工 具 书 ， 就 是 带 在 身 边 的 一 本 书 ， 没 事 就 看 一 看 ， 能 让 你 不 断 地 提 高 自 己 。 我 自 己 买 了 四 本 ， 自 己 留 了 两 本 ， 另 外 的 两 本 送 给 了 原 来 的 同 事 。\n",
      "label= 1\n",
      "predict= 1\n",
      "比 较 实 用 的 机 器 ， 加 内 存 时 一 看 背 后 没 有 能 开 的 小 盖 子 ， 应 该 是 在 掌 托 下 面 ， 还 是 打 了 8009908888 问 了 一 下 。 拆 开 看 了 就 会 知 道 散 热 做 的 相 当 好 ， 机 器 用 料 很 到 位 ， 相 当 结 实 。\n",
      "label= 1\n",
      "predict= 1\n",
      "我 怎 么 感 觉 p8600 和 另 外 一 个 本 联 想 f31 的 t2390 一 样 的 。 。 。 没 感 觉 出 啥 区 别 来\n",
      "label= 0\n",
      "predict= 0\n",
      "房 间 很 好 ， 设 施 也 很 新 ， 总 的 来 说 还 是 很 满 意 的 。 周 围 的 很 方 便 ， 楼 下 隔 壁 都 有 商 场 。\n",
      "label= 1\n",
      "predict= 1\n",
      "很 超 值 的 酒 店 ， 对 得 起 这 个 价 格 ， 但 是 上 网 步 骤 比 较 麻 烦 ， 网 速 一 般 。\n",
      "label= 1\n",
      "predict= 1\n",
      "配 置 高 ， 外 观 漂 亮 ， 画 质 细 腻 触 摸 板 做 了 凸 点 处 理, 手 感 也 不 错 白 色 键 盘 抢 眼\n",
      "label= 1\n",
      "predict= 1\n"
     ]
    }
   ],
   "source": [
    "# Chapter 6 / Test\n",
    "import torch\n",
    "\n",
    "model.eval()\n",
    "\n",
    "# Take just the first evaluation batch.\n",
    "data = next(iter(trainer.get_eval_dataloader()))\n",
    "\n",
    "# Move the batch to the model's own device instead of assuming CUDA, so the\n",
    "# cell also runs on a CPU-only machine.\n",
    "for k, v in data.items():\n",
    "    data[k] = v.to(model.device)\n",
    "\n",
    "# Inference only, so skip gradient tracking.\n",
    "with torch.no_grad():\n",
    "    out = model(**data)\n",
    "out = out['logits'].argmax(dim=1)\n",
    "\n",
    "# Loop over the actual batch size rather than a hard-coded 16, which would\n",
    "# raise an IndexError for a smaller batch.\n",
    "for i in range(len(data['labels'])):\n",
    "    print(tokenizer.decode(data['input_ids'][i], skip_special_tokens=True))\n",
    "    print('label=', data['labels'][i].item())\n",
    "    print('predict=', out[i].item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
